from predefined_tasks.pmt_adjust import *


async def pmt_instrument_discriminator_scan(channel='uslum', dl_adjust=0.380, dl_start=0.0, dl_stop=1.0, dl_step=0.001, window_ms=100, iterations=10, hv_enable=0, hts_alpha_enable=0, report_file=''):
    """Scan the discriminator level of one instrument PMT channel and log count rates.

    For every discriminator level from ``dl_start`` to ``dl_stop`` (inclusive,
    step ``dl_step``) the level is set, a counting measurement is run and the
    mean/max count rates are sent to the GC and appended to the report. The
    scan is plotted with a marker at ``dl_adjust``. The channel's HV enable
    state and discriminator level read at the start are restored afterwards,
    also after a user stop or an error.

    Args:
        channel: 'pmt1', 'pmt2' or 'uslum'.
        dl_adjust: Discriminator level marked in the plot.
        dl_start, dl_stop, dl_step: Discriminator scan range (``dl_stop`` inclusive).
        window_ms: Counting window length passed to ``pmt_counting_measurement``.
        iterations: Iteration count passed to ``pmt_counting_measurement``.
        hv_enable: HV enable state applied to the channel during the scan.
        hts_alpha_enable: If truthy, the HTS alpha laser is enabled during the
            scan and disabled afterwards.
        report_file: Report path; defaults to ``pmt_instrument_discriminator_scan.csv``
            in ``report_path``.

    Returns:
        A status string.

    Raises:
        ValueError: If ``channel`` is not supported.
    """
    supported_channels = ('pmt1', 'pmt2', 'uslum')
    if channel not in supported_channels:
        # Only these channels expose readable HV-enable/DL parameters; for any other
        # channel the original state could not be restored in the `finally` below.
        raise ValueError(f"channel must be one of {supported_channels}, got {channel!r}")

    GlobalVar.set_stop_gc(False)

    window_count = 1

    # Parameters holding the channel's current state, read now so it can be restored later.
    state_parameters = {
        'pmt1': (MeasurementParameter.PMT1HighVoltageEnable, MeasurementParameter.PMT1DiscriminatorLevel),
        'pmt2': (MeasurementParameter.PMT2HighVoltageEnable, MeasurementParameter.PMT2DiscriminatorLevel),
        'uslum': (MeasurementParameter.PMTUSLUMHighVoltageEnable, MeasurementParameter.PMTUSLUMDiscriminatorLevel),
    }
    hv_enable_parameter, dl_parameter = state_parameters[channel]
    current_hv_enable = (await meas_endpoint.GetParameter(hv_enable_parameter))[0]
    current_dl = (await meas_endpoint.GetParameter(dl_parameter))[0]

    report_file = str(report_file) if (report_file != '') else os.path.join(report_path, "pmt_instrument_discriminator_scan.csv")

    try:
        with open(report_file, 'a') as report:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            report.write(f"pmt_instrument_discriminator_scan(channel={channel}, dl_adjust={dl_adjust}, dl_start={dl_start:.3f}, dl_stop={dl_stop:.3f}, dl_step={dl_step:.3f}, window_ms={window_ms}, iterations={iterations}, hv_enable={hv_enable}, hts_alpha_enable={hts_alpha_enable}) started at {timestamp}\n")
            report.write(f"temperature: {await pmt_get_temperature(channel)}\n")
            report.write('\n')

            # The small epsilon keeps dl_stop inside the range despite float rounding.
            dl_range = np.arange(dl_start, (dl_stop + 1e-6), dl_step).round(6)

            await send_to_gc(f"dl    ; {channel + '_cps_mean':14} ; {channel + '_cps_max':14}", report=report)

            if hts_alpha_enable:
                await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 1)

            await pmt_set_hv_enable(channel, hv_enable)
            await asyncio.sleep(pmt_set_hv_enable_delay)

            cps_mean = np.zeros_like(dl_range)
            cps_max = np.zeros_like(dl_range)

            for i, dl in enumerate(dl_range):
                await pmt_set_dl(channel, dl)
                await asyncio.sleep(pmt_set_dl_delay)
                results = await pmt_counting_measurement(window_ms, window_count, iterations)
                if GlobalVar.get_stop_gc():
                    return "pmt_instrument_discriminator_scan stopped by user"

                cps_mean[i] = results[f"{channel}_cps_mean"]
                cps_max[i] = results[f"{channel}_cps_max"]

                await send_to_gc(f"{dl:5.3f} ; {cps_mean[i]:14.0f} ; {cps_max[i]:14.0f}", report=report)

            report.write('\n')

            await plot_instrument_dl_scan(channel, dl_range, cps_mean, cps_max, dl_adjust, hv_enable)

    finally:
        # Always put the instrument back into the state found at the start.
        if hts_alpha_enable:
            await meas_endpoint.SetParameter(MeasurementParameter.HTSAlphaLaserEnable, 0)
        await pmt_set_hv_enable(channel, current_hv_enable)
        await pmt_set_dl(channel, current_dl)

    return "pmt_instrument_discriminator_scan done"


async def plot_instrument_dl_scan(channel, dl_range, cps_mean, cps_max, dl_adjust, hv_enable, file_name='graph.png'):
    """Draw a two-panel discriminator scan plot and ask the GC to refresh it.

    The top panel shows ``cps_mean`` and the bottom panel ``cps_max`` against
    ``dl_range``, both on a symlog y axis (linear threshold 1) with a red
    vertical line at ``dl_adjust``. The title reports whether HV was on. The
    figure is saved as ``file_name`` under ``images_path``.
    """
    plt.clf()

    hv_state = 'ON' if hv_enable else 'OFF'
    panels = (
        (1, 'CPS_MEAN', cps_mean),
        (2, 'CPS_MAX', cps_max),
    )
    for position, label_suffix, values in panels:
        plt.subplot(2, 1, position)
        # Title goes on the top panel, the shared x label on the bottom one.
        if position == 1:
            plt.title(f"Discriminator Scan (HV = {hv_state})")
        else:
            plt.xlabel('Discriminator Level')
        plt.ylabel(f"{channel.upper()}_{label_suffix}")
        plt.yscale('symlog', linthresh=1)
        plt.plot(dl_range, values)
        plt.axvline(dl_adjust, color='r')

    image_path = os.path.join(images_path, file_name)
    plt.savefig(image_path)
    await send_gc_event('RefreshGraph', file_name=os.path.join('pmt_adjust_images', file_name))


async def pmt_high_voltage_scan(channel='pmt1', dl=0.26, hv_adjust=0.475, hv_start=0.2, hv_stop=0.73, hv_step=0.005, led_current='0;1', window_ms=1000, iterations=1, temperature=0, shutdown=0, report_file=''):
    """Scan a PMT's high voltage at one or more LED currents and log the count rate.

    For each LED current the high voltage is stepped from ``hv_start`` to
    ``hv_stop`` (inclusive, step ``hv_step``). At each step a counting
    measurement is run, and the mean count rate (plus the PMT temperature when
    ``temperature > 0``) is sent to the GC and appended to the report. One plot
    per LED current is produced with red markers at ``hv_adjust``. The PMT high
    voltage and the LED are always switched off at the end, including after a
    user stop or an error.

    Args:
        channel: One of 'pmt1', 'pmt2', 'uslum', 'htsal'.
        dl: Discriminator level used during the scan.
        hv_adjust: Plot marker position(s); a number or a ';'-separated string of numbers.
        hv_start, hv_stop, hv_step: High voltage scan range (``hv_stop`` inclusive).
        led_current: LED current(s); a number or a ';'-separated string of integers.
        window_ms: Counting window length passed to ``pmt_counting_measurement``.
        iterations: Iteration count passed to ``pmt_counting_measurement`` at each step.
        temperature: If > 0, pmt1/pmt2 cooling is enabled with this target and the
            PMT temperature is logged at each step.
        shutdown: If truthy, the base tester is disabled when the scan ends.
        report_file: Report path; defaults to ``pmt_high_voltage_scan.csv`` in ``report_path``.

    Returns:
        A status string.

    Raises:
        KeyError: If ``channel`` has no high voltage limit.
        ValueError: If ``hv_start`` or ``hv_stop`` exceeds the channel's limit.
    """
    # Validate the requested range before touching any global or hardware state.
    hv_limit = {'pmt1': 0.73, 'pmt2': 0.73, 'uslum': 0.82, 'htsal': 0.82}
    if (hv_start > hv_limit[channel]) or (hv_stop > hv_limit[channel]):
        raise ValueError(f"The high voltage setting of {channel} must not exceed {hv_limit[channel]}")

    GlobalVar.set_stop_gc(False)

    if isinstance(hv_adjust, str):
        hv_adjust = [float(hv) for hv in hv_adjust.split(';')]
    else:
        hv_adjust = [hv_adjust]

    if isinstance(led_current, str):
        led_current = [int(current) for current in led_current.split(';')]
    else:
        led_current = [led_current]

    led_source, led_channel, led_type = led_dimmed[channel].split('_')

    report_file = str(report_file) if (report_file != '') else os.path.join(report_path, "pmt_high_voltage_scan.csv")

    # The scan range does not depend on the LED current, so build it once.
    # The small epsilon keeps hv_stop inside the range despite float rounding.
    hv_range = np.arange(hv_start, (hv_stop + 1e-6), hv_step).round(6)
    window_count = 1

    await base_tester_enable(True, nested=False)
    try:
        if temperature > 0 and channel == 'pmt1':
            await pmt1_cooling.InitializeDevice()
            await pmt1_cooling.set_target_temperature(temperature)
            await pmt1_cooling.enable()
        if temperature > 0 and channel == 'pmt2':
            await pmt2_cooling.InitializeDevice()
            await pmt2_cooling.set_target_temperature(temperature)
            await pmt2_cooling.enable()

        await pmt_set_dl(channel, dl)
        await pmt_set_hv(channel, hv_start)
        await asyncio.sleep(pmt_set_hv_delay)
        await pmt_set_hv_enable(channel, 1)
        await asyncio.sleep(pmt_set_hv_enable_delay)

        with open(report_file, 'a') as report:
            timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            report.write(f"pmt_high_voltage_scan(channel={channel}, dl={dl}, hv_adjust={hv_adjust}, hv_start={hv_start}, hv_stop={hv_stop}, hv_step={hv_step}, led_current={led_current}, window_ms={window_ms}, iterations={iterations}, temperature={temperature}, shutdown={shutdown}) started at {timestamp}\n")
            report.write('\n')

            for current in led_current:
                await send_to_gc(f"LED: {current}", log=True, report=report)
                await send_to_gc(" ")

                await set_led_current(current, led_source, led_channel, led_type)

                if temperature > 0:
                    await send_to_gc(f"hv    ; {channel + '_cps':9} ; temperature", log=True, report=report)
                else:
                    await send_to_gc(f"hv    ; {channel + '_cps':9}", log=True, report=report)

                cps = np.zeros_like(hv_range)

                for i, hv in enumerate(hv_range):
                    await pmt_set_hv(channel, hv)
                    await asyncio.sleep(pmt_set_hv_delay)
                    # Argument order (window_ms, window_count, iterations), as used by the
                    # discriminator scan; previously `iterations` landed in the window_count slot.
                    results = await pmt_counting_measurement(window_ms, window_count, iterations)
                    if GlobalVar.get_stop_gc():
                        return "pmt_high_voltage_scan stopped by user"

                    cps[i] = results[f"{channel}_cps_mean"]

                    if temperature > 0:
                        await send_to_gc(f"{hv:.3f} ; {cps[i]:9.0f} ; {await pmt_get_temperature(channel)}", log=True, report=report)
                    else:
                        await send_to_gc(f"{hv:.3f} ; {cps[i]:9.0f}", log=True, report=report)

                report.write('\n')

                await plot_hv_scan(channel, hv_range, cps, hv_adjust, current, file_name=f"pmt_hv_scan_led_{current}.png")

    finally:
        # Never leave the PMT under high voltage or the LED lit, whatever the exit path.
        try:
            await pmt_set_hv_enable(channel, 0)
            await set_led_current(0, led_source, led_channel, led_type)
        finally:
            if shutdown:
                await base_tester_enable(False, nested=False)

    return "pmt_high_voltage_scan done"


async def plot_hv_scan(channel, hv_range, cps, hv_adjust, led_current, file_name='graph.png'):
    """Plot count rate against high voltage for one LED current and refresh the GC graph.

    ``hv_adjust`` may be a single value or any iterable of values. Each value
    gets a red vertical marker. NOTE: a string also counts as iterable here and
    would be iterated character by character. The figure is saved as
    ``file_name`` under ``images_path``.
    """
    plt.clf()

    plt.title(f"High Voltage Scan LED={led_current}")
    plt.xlabel('hv')
    plt.ylabel(channel.upper())
    plt.plot(hv_range, cps)

    markers = hv_adjust if hasattr(hv_adjust, '__iter__') else [hv_adjust]
    for marker in markers:
        plt.axvline(marker, color='r')

    image_path = os.path.join(images_path, file_name)
    plt.savefig(image_path)
    await send_gc_event('RefreshGraph', file_name=os.path.join('pmt_adjust_images', file_name))


async def pmt_pdd_test(window_ms, window_count, iterations):
    """Report PMT1 count rate alongside the PDD reading over repeated measurements.

    Starts the FMB and EEF firmware, sets the FMB green LED2 to current 2, and
    sets PMT1 to discriminator level 0.26 and HV 0.475 before enabling HV. It
    then runs ``iterations`` counting measurements of ``window_count`` windows
    of ``window_ms`` each, sending one line per iteration to the GC. PMT1 HV and
    the LED are switched off at the end, including after a user stop or an error.

    Returns:
        A status string.
    """
    # Clear any stop request left over from a previous task, like the other tasks do.
    GlobalVar.set_stop_gc(False)

    await send_to_gc("Starting Firmware", log=True)
    await asyncio.gather(
        fmb_unit.StartFirmware(),
        eef_unit.StartFirmware(),
    )

    try:
        await set_led_current(2, 'fmb', 'led2', 'green')

        await pmt_set_dl('pmt1', 0.26)
        await pmt_set_hv('pmt1', 0.475)
        await asyncio.sleep(pmt_set_hv_delay)
        await pmt_set_hv_enable('pmt1', 1)
        await asyncio.sleep(pmt_set_hv_enable_delay)

        op_id = 'pmt_pdd_test'
        meas_unit.ClearOperations()
        await load_pmt_counting_measurement(op_id, window_ms, window_count)

        await send_to_gc(f" ; pmt         ; pdd", log=True)

        for _ in range(iterations):
            if GlobalVar.get_stop_gc():
                return "pmt_pdd_test stopped by user"

            await meas_unit.ExecuteMeasurement(op_id)
            results = await meas_unit.ReadMeasurementValues(op_id)

            # results[0]/results[1] are combined as low/high 32-bit words of the PMT1 count.
            pmt1_cps = (results[0] + (results[1] << 32)) / window_count / window_ms * 1000.0
            pmt1_pdd = (results[16] + results[17] * config['pmt1_pdd_scaling']) / window_count / window_ms * 1000.0

            await send_to_gc(f" ; {pmt1_cps:11.1f} ; {pmt1_pdd:11.1f}", log=True)

    finally:
        # Never leave PMT1 under high voltage or the LED lit, whatever the exit path.
        try:
            await pmt_set_hv_enable('pmt1', 0)
        finally:
            await set_led_current(0, 'fmb', 'led2', 'green')

    return "pmt_pdd_test done"


async def usfm_to_max():
    """Start the MC6 firmware, then initialize and home the USFM and move it to its Max position."""
    await mc6_unit.StartFirmware()

    mover = usfm_unit
    await mover.InitializeDevice()
    await mover.Home()
    target_position = focus_mover_config.Positions.Max
    await mover.GotoPosition(target_position)


async def pmt_adjust_enable(channel='pmt1;pmt2;uslum;htsal', enable=1):
    """Bring the PMT adjustment stage up (``enable`` truthy) or down.

    Channels are selected by substring membership in ``channel``; the default
    ';'-separated string selects all four.

    Enabling:
        - switches the base tester on;
        - for pmt1/pmt2, initializes the cooling, sets its target to 18.0 and enables it;
        - for htsal, initializes the HTS alpha cooling, sets its target to -273.15
          and disables it;
        - for uslum or htsal, enables the USLUM fan, then initializes and homes
          the USFM and moves it to its Max position.

    Disabling:
        - for htsal, disables the HTS alpha cooling;
        - for uslum or htsal, moves the USFM to position 0;
        - switches the base tester off.

    Returns:
        'Adjustment Stage Enabled' or 'Adjustment Stage Disabled'.
    """
    if not enable:
        if 'htsal' in channel:
            await hts_alpha_cooling.disable()
        if any(name in channel for name in ('uslum', 'htsal')):
            await usfm_unit.Move(0)
        await base_tester_enable(False, nested=False)
        return 'Adjustment Stage Disabled'

    await base_tester_enable(True, nested=False)

    for name, cooling in (('pmt1', pmt1_cooling), ('pmt2', pmt2_cooling)):
        if name in channel:
            await cooling.InitializeDevice()
            await cooling.set_target_temperature(18.0)
            await cooling.enable()

    if 'htsal' in channel:
        await hts_alpha_cooling.InitializeDevice()
        await hts_alpha_cooling.set_target_temperature(-273.15)
        await hts_alpha_cooling.disable()

    if any(name in channel for name in ('uslum', 'htsal')):
        await uslum_fan.InitializeDevice()
        await uslum_fan.enable()
        await usfm_unit.InitializeDevice()
        await usfm_unit.Home()
        await usfm_unit.GotoPosition(focus_mover_config.Positions.Max)

    return 'Adjustment Stage Enabled'


async def create_signal_scan_plot():
    """Plot a hard-coded PMT1 signal scan (HV 0.440) against a scaled reference.

    The measured count rates are passed through ``pmt_calculate_correction``
    with a 15.82 ns parameter. The reference curve is then scaled by the mean
    ratio of corrected counts to reference over the first 15 points. The
    resulting dicts (keyed by channel) are handed to ``plot_signal_scan``; the
    reference dict is passed twice, as before.
    """
    channel = 'pmt1'
    high_voltage = 0.440
    ppr_time_ns = 15.82
    fit_points = 15  # points used to match the reference scale to the measurement

    reference_counts = np.array([2962,6887,11319,15687,20491,30246,44332,59484,84794,114697,161701,221803,310688,429420,597873,1038210,1714382,2386053,3122131,4618090,6799621,9146170,12996078,18490861,25053757,35354850,49256024,69237444,96884176])
    measured_cps = np.array([13491,31380,52064,72133,94458,138398,203697,273655,390398,529550,746195,1022530,1433880,1974995,2715893,4586153,7272287,10196814,12105701,16332877,21318055,25447924,30545902,35592328,39829119,44414874,48439404,52018001,55201301])

    corrected_cps = pmt_calculate_correction(measured_cps, ppr_time_ns)
    scale = np.mean(corrected_cps[:fit_points] / reference_counts[:fit_points])

    signal_cps = {channel: corrected_cps}
    signal_ref = {channel: scale * reference_counts}

    plot_name = f"signal_scan_hv_{high_voltage:.3f}.png"
    await plot_signal_scan([channel], signal_ref, signal_cps, signal_ref, high_voltage, file_name=plot_name)

